package main

import (
	"bufio"
	"encoding/binary"
	"fmt"
	"github.com/cilium/ebpf"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"os/signal"
	"strconv"
	"sync"
	"syscall"
)

//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -output-stem src -cflags $BPF_CFLAGS bpf tc_neigh_src.c -- -I../headers
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -output-stem des -cflags $BPF_CFLAGS bpf tc_neigh_des.c -- -I../headers

// backend is the value that main writes into the pinned backends_srcs /
// backends_dess BPF map, keyed by an ip2int-encoded IPv4 address.
// Field order and widths determine the marshalled layout (4+4+2+2 = 12 bytes,
// no padding). NOTE(review): presumably mirrors a struct in tc_neigh_src.c /
// tc_neigh_des.c — confirm there before reordering or resizing fields.
type backend struct {
	daddr    uint32 // destination IPv4 address, encoded with ip2int
	saddr    uint32 // source IPv4 address, encoded with ip2int
	ifdindex uint16 // ifindex of the other interface (truncated to 16 bits)
	ifindex  uint16 // ifindex of the interface init.sh attached to (truncated to 16 bits)
}

// main attaches one side of the tc neighbour-redirect program and seeds its
// pinned backend map. It then blocks until SIGINT/SIGTERM and detaches the
// program again.
//
// Usage: <prog> <src-iface> <dest-iface> <src|des>
//
// "src" attaches via ./init.sh to <src-iface> and seeds backends_srcs.
// "des" attaches to <dest-iface> and seeds backends_dess.
func main() {
	if len(os.Args) < 4 {
		log.Fatalf("usage: %s <src-iface> <dest-iface> <src|des>", os.Args[0])
	}

	// 1. Source interface.
	ifaceSrcName := os.Args[1]
	ifaceSrc, err := net.InterfaceByName(ifaceSrcName)
	if err != nil {
		log.Fatalf("lookup network iface %q: %s", ifaceSrcName, err)
	}
	// 2. Destination interface.
	ifaceDest, err := net.InterfaceByName(os.Args[2])
	if err != nil {
		log.Fatalf("lookup network iface %q: %s", os.Args[2], err)
	}
	ifaceDestName := os.Args[2]

	// 3. Which side to manage. Validate up front so the interface we attach
	// to and the map we seed always belong to the same side.
	obType := os.Args[3]
	if obType != "src" && obType != "des" {
		log.Fatalf("unknown program type %q: want \"src\" or \"des\"", obType)
	}
	ifaceName := ifaceSrcName
	if obType == "des" {
		ifaceName = ifaceDestName
	}

	// Subscribe before attaching anything. A signal that arrives during setup
	// is then queued instead of killing the process with the program still
	// attached. signal.Notify requires a buffered channel.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)

	// 4. Attach the tc program via init.sh and show its output.
	out, err := runShellScript("./init.sh", ifaceName, "1", obType)
	fmt.Println("shell out:" + out)
	// 5. Report the exit status. A non-zero status is printed but not treated
	// as fatal. Failing to run the script at all is fatal.
	if err != nil {
		exitErr, ok := err.(*exec.ExitError)
		if !ok {
			log.Fatalf("running ./init.sh: %s", err)
		}
		fmt.Println("status:" + strconv.Itoa(exitErr.ExitCode()))
	}

	// 7. Detach. remove.sh gets the same interface init.sh attached to. The
	// Once guarantees it runs at most once, whichever exit path reaches it.
	var cleanupOnce sync.Once
	cleanup := func() {
		cleanupOnce.Do(func() {
			out, err := runShellScript("./remove.sh", ifaceName, obType)
			fmt.Println(out)
			if err != nil {
				log.Println(err)
			}
		})
	}
	// fail detaches before exiting, so a setup error does not leave the tc
	// program attached.
	fail := func(format string, args ...interface{}) {
		cleanup()
		log.Fatalf(format, args...)
	}

	// 6. Seed the pinned backend map for this side. The address pair is
	// hard-coded; the "des" side mirrors the "src" side.
	const baseDir = "/sys/fs/bpf/tc/globals/"
	var (
		mapFile           string
		srcAddr, desAddr  uint32
		srcKey            uint32
		ifIndex, ifdIndex uint16
	)
	if obType == "src" {
		mapFile = "backends_srcs"
		srcAddr = ip2int("192.168.0.21")
		desAddr = ip2int("192.168.0.22")
		srcKey = ip2int("192.168.0.22")
		ifIndex = uint16(ifaceSrc.Index)
		ifdIndex = uint16(ifaceDest.Index)
	} else {
		mapFile = "backends_dess"
		srcAddr = ip2int("192.168.0.22")
		desAddr = ip2int("192.168.0.21")
		srcKey = ip2int("192.168.0.21")
		ifIndex = uint16(ifaceDest.Index)
		ifdIndex = uint16(ifaceSrc.Index)
	}

	mapPath := baseDir + mapFile
	srcBack, err := ebpf.LoadPinnedMap(mapPath, nil)
	if err != nil {
		fail("loading pinned map %s: %s", mapPath, err)
	}
	defer srcBack.Close()

	b := backend{
		daddr:    desAddr,
		saddr:    srcAddr,
		ifdindex: ifdIndex,
		ifindex:  ifIndex,
	}
	if err := srcBack.Update(srcKey, b, ebpf.UpdateAny); err != nil {
		fail("updating map %s: %s", mapPath, err)
	}

	// 8. Run until SIGINT/SIGTERM, then detach. Returning from main instead of
	// calling os.Exit lets the deferred map Close run.
	<-sig
	cleanup()
}

// runShellScript runs the executable at path with args, waits for it to
// exit, and returns everything it wrote to stdout. A non-zero exit status is
// returned unwrapped as *exec.ExitError. Failures to start the process or to
// read its output are returned as other error types.
func runShellScript(path string, args ...string) (string, error) {
	cmd := exec.Command(path, args...)
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return "", fmt.Errorf("stdout pipe for %s: %w", path, err)
	}
	if err := cmd.Start(); err != nil {
		return "", fmt.Errorf("starting %s: %w", path, err)
	}

	var out []byte
	var readErr error
	reader := bufio.NewReader(stdout)
	for {
		line, err := reader.ReadBytes('\n')
		// Append before checking err. A final line without a trailing
		// newline is returned together with io.EOF.
		out = append(out, line...)
		if err != nil {
			if err != io.EOF {
				readErr = fmt.Errorf("reading %s output: %w", path, err)
			}
			break
		}
	}

	// Always Wait so the child is reaped and the pipe released, even after a
	// read error.
	if err := cmd.Wait(); err != nil {
		return string(out), err
	}
	return string(out), readErr
}
// deferFunc returns immediately after arranging for f to run when the
// process receives SIGINT or SIGTERM. Once f returns, the process terminates
// with exit status 1. The wait happens on a background goroutine.
func deferFunc(f func()) {
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

	waitThenRun := func() {
		<-stop
		f()
		os.Exit(1)
	}
	go waitThenRun()
}

// ip2int converts a dotted-quad IPv4 string to a uint32 by reading its four
// octets as little-endian. On a little-endian host the value therefore sits
// in memory in network byte order (192.168.0.21 is stored as c0 a8 00 15).
// NOTE(review): presumably chosen to match __be32 fields on the BPF side —
// confirm against tc_neigh_*.c.
//
// It panics with a descriptive message if ip is not a valid IPv4 (or
// IPv4-mapped) address. Previously such input caused an opaque
// index-out-of-range panic.
func ip2int(ip string) uint32 {
	v4 := net.ParseIP(ip).To4()
	if v4 == nil {
		panic(fmt.Sprintf("ip2int: %q is not a valid IPv4 address", ip))
	}
	return binary.LittleEndian.Uint32(v4)
}
